import argparse
import torch
import os
import re
import json
from tqdm import tqdm
import random
import requests
from io import BytesIO
import torch.multiprocessing as mp
from torch.multiprocessing import Process, Manager
import time

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

from tci_attn import LlamaAttentionWithLogits
from PIL import Image


def load_image(image_file: str, timeout: float = 30.0) -> Image.Image:
    """Load an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait on the HTTP server before giving up. Only
            used for URLs.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: if the URL responds with an error status.
    """
    if image_file.startswith(("http://", "https://")):
        # Without a timeout, requests can block forever on an unresponsive server.
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # convert() returns a new, fully loaded image, so the handle opened by
    # Image.open can be closed right away instead of leaking until GC.
    with Image.open(source) as img:
        return img.convert("RGB")

def load_images(image_files: list[str]) -> list[Image.Image]:
    """Load each path/URL in ``image_files`` via ``load_image``, keeping order."""
    return [load_image(path) for path in image_files]

def image_parser(image_arg: str, sep: str) -> list[str]:
    """Split ``image_arg`` on ``sep`` into a list of image paths/URLs."""
    pieces = image_arg.split(sep)
    return pieces

# Decoder-layer indices whose self-attention is replaced when TCI is enabled
# and args_dict does not supply its own 'tci_layers'.
_DEFAULT_TCI_LAYERS = (0, 1, 14, 15, 17)

# Query used when args_dict does not supply its own 'query'.
_DEFAULT_QUERY = "Describe the image in detail."


def _build_input_ids(query, tokenizer, model_config, conv_mode, device):
    """Wrap ``query`` in the ``conv_mode`` chat template and tokenize it.

    The image token (framed by the start/end tokens when
    ``model_config.mm_use_im_start_end`` is set) replaces ``IMAGE_PLACEHOLDER``
    if the query contains it; otherwise it is prepended on its own line.

    Returns:
        The token ids on ``device``, with a leading batch dimension of 1 added.
    """
    if model_config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    if IMAGE_PLACEHOLDER in query:
        qs = re.sub(IMAGE_PLACEHOLDER, image_token, query)
    else:
        qs = image_token + "\n" + query

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    # Second role with no message: presumably leaves the reply slot open for generation.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    return (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )


def _patch_tci_attention(model, layer_indices, alpha, device, dtype):
    """Replace self-attention in the selected decoder layers with LlamaAttentionWithLogits.

    Each replacement is built from the original module's config, receives the
    original weights via ``load_state_dict``, and is moved to ``device``/``dtype``.
    """
    selected = set(layer_indices)
    for idx, layer in enumerate(model.model.layers):
        if idx not in selected:
            continue
        replacement = LlamaAttentionWithLogits(layer.self_attn.config, layer_idx=idx, alpha=alpha)
        replacement.load_state_dict(layer.self_attn.state_dict())
        layer.self_attn = replacement.to(device=device, dtype=dtype)


def eval_model(rank, args_dict, shared_results, gpu_times):
    """Caption this worker's shard of ``args_dict['data']`` on a single GPU.

    ``rank`` selects the GPU id ``args_dict['gpu_ids'][rank]`` and the shard
    ``args_dict['data'][rank::len(args_dict['gpu_ids'])]``.

    Args:
        rank: Index of this worker.
        args_dict: Run configuration. Required keys as used below. Optional
            keys are ``torch_dtype`` (default ``torch.float16``), ``tci_layers``
            (default ``_DEFAULT_TCI_LAYERS``), ``query`` (default
            ``_DEFAULT_QUERY``) and ``seed`` (when not None, seeds torch
            before generation).
        shared_results: List-like; extended with one
            ``{"image", "original_description"}`` dict per processed image.
        gpu_times: Dict-like; ``gpu_times[rank]`` is set to
            ``{'gpu_id', 'samples', 'time'}``.
    """
    visible_device = args_dict['gpu_ids'][rank]
    # Expose only this worker's GPU, which then appears as cuda:0. Presumably
    # relies on CUDA not yet being initialised in this freshly started process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(visible_device)
    device = torch.device("cuda:0")
    dtype = args_dict.get('torch_dtype', torch.float16)

    disable_torch_init()

    # 1. load model
    model_name = get_model_name_from_path(args_dict['model_path'])
    tokenizer, model, image_processor, _context_len = load_pretrained_model(
        model_path=args_dict['model_path'],
        model_base=args_dict['model_base'],
        model_name=model_name,
        torch_dtype=dtype,
    )
    model.to(device)

    total_gpus = len(args_dict['gpu_ids'])
    data_shard = args_dict['data'][rank::total_gpus]

    results = []

    # record start time
    start_time = time.time()

    if args_dict['tci']:
        _patch_tci_attention(
            model,
            args_dict.get('tci_layers', _DEFAULT_TCI_LAYERS),
            args_dict['alpha'],
            device,
            dtype,
        )

    # The prompt does not depend on the image, so build and tokenize it once.
    input_ids = _build_input_ids(
        args_dict.get('query', _DEFAULT_QUERY),
        tokenizer,
        model.config,
        args_dict['conv_mode'],
        device,
    )
    do_sample = args_dict['temperature'] > 0

    seed = args_dict.get('seed')
    if seed is not None:
        # Seed sampling so runs are repeatable. GPU kernels may still add some
        # nondeterminism.
        torch.manual_seed(seed)

    for item in tqdm(data_shard, desc=f"GPU {visible_device} Processing"):
        image_file = os.path.join(args_dict['image_folder'], item["image"])
        image = load_image(image_file)

        images_tensor = process_images([image], image_processor, model.config).to(
            device, dtype=dtype
        )

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=[image.size],
                do_sample=do_sample,
                temperature=args_dict['temperature'],
                top_p=args_dict['top_p'],
                top_k=args_dict['top_k'],
                num_beams=args_dict['num_beams'],
                max_new_tokens=args_dict['max_new_tokens'],
                use_cache=True,
            )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        results.append({
            "image": item["image"],
            "original_description": outputs,
        })

    # record end time
    elapsed_time = time.time() - start_time

    shared_results.extend(results)
    gpu_times[rank] = {
        'gpu_id': visible_device,
        'samples': len(results),
        'time': elapsed_time
    }
    print(f"GPU {visible_device} finished processing {len(results)} samples in {elapsed_time:.2f}")

def process_seed(seed, image_folder, num_samples, base_args, output_folder):
    """Caption a seeded random sample of images from ``image_folder`` on several GPUs.

    One worker process per entry of ``base_args['gpu_ids']`` runs ``eval_model``.
    The collected results are written as JSON to ``output_folder``.

    Args:
        seed: Seed for Python's global ``random`` module (selects the images).
            It is also placed in the worker config under ``'seed'``.
        image_folder: Directory scanned (non-recursively) for image files.
        num_samples: Number of images to sample. All are used if fewer exist.
        base_args: Shared configuration dict. It is shallow-copied, not mutated.
        output_folder: Directory the JSON results file is written to.

    Returns:
        ``(output_file, total_compute_time, wall_clock_time, num_images)``.

    Raises:
        RuntimeError: if any worker process exits with a non-zero exit code.
    """
    # set random seeds
    random.seed(seed)

    start_time = time.time()

    # os.listdir order is filesystem-dependent; sort so the same seed selects
    # the same images on every machine.
    image_files = sorted(
        name for name in os.listdir(image_folder)
        if name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))
    )

    if len(image_files) < num_samples:
        print(f"WARNING: only {len(image_files)} images in folder, will use all instead of sampling {num_samples} images")
        selected_images = image_files
    else:
        selected_images = random.sample(image_files, num_samples)

    args_dict = base_args.copy()
    args_dict['data'] = [{"image": img} for img in selected_images]
    # Make the seed part of the worker configuration.
    args_dict['seed'] = seed

    mp.set_start_method('spawn', force=True)
    with Manager() as manager:
        shared_results = manager.list()
        shared_times = manager.dict()

        processes = []
        for rank in range(len(args_dict['gpu_ids'])):
            p = Process(target=eval_model, args=(rank, args_dict, shared_results, shared_times))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # A crashed worker leaves its shard out of the results; refuse to
        # write a silently incomplete file.
        failed = [
            args_dict['gpu_ids'][rank]
            for rank, p in enumerate(processes)
            if p.exitcode != 0
        ]
        if failed:
            raise RuntimeError(
                f"worker process(es) for GPU(s) {failed} failed; results for seed {seed} are incomplete"
            )

        # Workers append in completion order; restore the sampled order so
        # the output file is deterministic.
        position = {img: i for i, img in enumerate(selected_images)}
        final_results = list(shared_results)
        final_results.sort(key=lambda r: position[r["image"]])

        gpu_times = dict(shared_times)
        total_compute_time = sum(info['time'] for info in gpu_times.values())

        # NOTE(review): "reslult" looks like a typo; kept unchanged in case
        # downstream scripts match this file name.
        output_file = os.path.join(output_folder, f"reslult_w+{args_dict['alpha']}_seed{seed}_multigpu.json")
        with open(output_file, "w") as f:
            json.dump(final_results, f, indent=4)

    wall_clock_time = time.time() - start_time

    print(f"\n seed {seed} GPU time in total:")
    for rank in sorted(gpu_times):
        gpu_info = gpu_times[rank]
        print(f"  GPU {gpu_info['gpu_id']}: {gpu_info['samples']} images, {gpu_info['time']:.2f} s")

    print(f"seed {seed} total compute time: {total_compute_time:.2f} s")
    # Guard against an image folder with no matching files.
    avg_compute_time = total_compute_time / len(selected_images) if selected_images else 0
    print(f"avg compute time: {avg_compute_time:.4f} s")

    return output_file, total_compute_time, wall_clock_time, len(selected_images)

def main(argv=None):
    """Caption seeded image samples for several seeds and report timing totals.

    Every setting can be overridden on the command line. The defaults are the
    values this script previously hard-coded, so running it without arguments
    behaves exactly as before.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. Optional.
    """
    parser = argparse.ArgumentParser(
        description="Generate LLaVA image descriptions for seeded image samples on multiple GPUs."
    )
    parser.add_argument("--model-path", default="llava-v1.5-7b")
    parser.add_argument("--image-folder", default="COCO_val2014")
    parser.add_argument("--output-folder", default="tci/results/chair")
    parser.add_argument("--num-samples", type=int, default=500)
    parser.add_argument("--seeds", type=int, nargs="+", default=[42, 1, 3, 5, 9])
    parser.add_argument("--gpu-ids", type=int, nargs="+", default=[0, 1, 2, 3, 4, 5, 6, 7])
    parser.add_argument("--conv-mode", default="vicuna_v1")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top-p", type=float, default=1)
    parser.add_argument("--top-k", type=int, default=None)
    parser.add_argument("--num-beams", type=int, default=1)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--tci", action="store_true",
                        help="replace selected attention layers with LlamaAttentionWithLogits")
    parser.add_argument("--alpha", type=float, default=4)
    args = parser.parse_args(argv)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_folder, exist_ok=True)

    base_args = {
        'model_path': args.model_path,
        'model_base': None,
        'conv_mode': args.conv_mode,
        'image_folder': args.image_folder,
        'output_folder': args.output_folder,
        'sep': ",",
        'temperature': args.temperature,
        'top_p': args.top_p,
        'top_k': args.top_k,
        'num_beams': args.num_beams,
        'max_new_tokens': args.max_new_tokens,
        'torch_dtype': torch.float16,
        'tci': args.tci,
        'alpha': args.alpha,
        'gpu_ids': args.gpu_ids,
    }

    total_compute_time = 0
    total_wall_clock_time = 0
    total_images = 0

    # process one seed at a time
    output_files = []
    for seed in args.seeds:
        print(f"\n process with {seed} ...")
        output_file, compute_time, wall_time, image_count = process_seed(
            seed, args.image_folder, args.num_samples, base_args, args.output_folder
        )
        output_files.append(output_file)

        total_compute_time += compute_time
        total_wall_clock_time += wall_time
        total_images += image_count

    # compute avg time
    avg_compute_time_per_image = total_compute_time / total_images if total_images > 0 else 0

    print("\n===== ALL DONE! =====")
    for seed, output_file in zip(args.seeds, output_files):
        print(f"seed {seed}: results saved in {os.path.basename(output_file)}")

    print(f"\n total compute time: {total_compute_time:.2f} s")
    print(f"total process images: {total_images}")
    print(f"avg compute time: {avg_compute_time_per_image:.4f} s")

# Entry-point guard: process_seed uses the 'spawn' start method, under which
# worker processes re-import this module; the guard keeps them from re-running main().
if __name__ == "__main__":
    main()